{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "05ab88b2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:45:41.580218Z",
     "start_time": "2023-05-14T10:45:40.384100Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15d072b9",
   "metadata": {},
   "source": [
    "## embedding 的计算过程"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95297d1a",
   "metadata": {},
   "source": [
    "- embedding module 的前向过程其实是一个索引（查表）的过程\n",
    "    - 表的形式是一个 matrix（embedding.weight, learnable parameters）\n",
    "        - matrix.shape: (v, h)\n",
    "            - v：vocabulary size\n",
    "            - h：hidden dimension\n",
    "    - 具体索引的过程，是通过 one hot + 矩阵乘法的形式实现的；\n",
    "    - input.shape: (b, s)\n",
    "        - b：batch size\n",
    "        - s：seq len\n",
    "    - embedding(input)\n",
    "        - (b, s) ==> (b, s, h)\n",
    "        - (b, s) 和 (v, h) ==>? (b, s, h)\n",
    "            - (b, s) 经过 one hot => (b, s, v)\n",
    "            - (b, s, v) @ (v, h) ==> (b, s, h)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76ac5217",
   "metadata": {},
   "source": [
    "### 简单前向"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8ddb05ac",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:48:55.246351Z",
     "start_time": "2023-05-14T10:48:55.240987Z"
    }
   },
   "outputs": [],
   "source": [
    "# an Embedding module containing 10 tensors of size 3\n",
    "embedding = nn.Embedding(10, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b1a83e3b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:48:56.586227Z",
     "start_time": "2023-05-14T10:48:56.567557Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[ 0.8376,  0.6068,  1.7555],\n",
       "        [ 0.4941,  0.1717, -0.2396],\n",
       "        [-1.8685,  1.2610, -0.5606],\n",
       "        [ 0.8324,  1.0663,  1.2586],\n",
       "        [-0.7126, -0.8973, -2.2054],\n",
       "        [ 0.7383,  0.2399,  0.1330],\n",
       "        [-1.3319, -0.5330,  0.9591],\n",
       "        [ 0.7808, -0.2259,  0.1930],\n",
       "        [ 1.1298,  0.1678,  1.1490],\n",
       "        [-0.6612, -0.9927, -0.4817]], requires_grad=True)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8ee6e671",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:49:21.901296Z",
     "start_time": "2023-05-14T10:49:21.892023Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.int64\n"
     ]
    }
   ],
   "source": [
    "# a batch of 2 samples of 4 indices each\n",
    "# b, s, \n",
    "input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])\n",
    "print(input.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "944a057a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:49:30.980544Z",
     "start_time": "2023-05-14T10:49:30.968113Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.4941,  0.1717, -0.2396],\n",
       "         [-1.8685,  1.2610, -0.5606],\n",
       "         [-0.7126, -0.8973, -2.2054],\n",
       "         [ 0.7383,  0.2399,  0.1330]],\n",
       "\n",
       "        [[-0.7126, -0.8973, -2.2054],\n",
       "         [ 0.8324,  1.0663,  1.2586],\n",
       "         [-1.8685,  1.2610, -0.5606],\n",
       "         [-0.6612, -0.9927, -0.4817]]], grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# (b, s, ) => (b, s, h)\n",
    "embedding(input)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a319ed45",
   "metadata": {},
   "source": [
    "### one-hot 矩阵乘法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "71db7a51",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:50:36.991469Z",
     "start_time": "2023-05-14T10:50:36.979393Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 4, 10])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],\n",
       "\n",
       "        [[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# num_classes == vocab size\n",
    "# (b, s) => (b, s, v)\n",
    "input_onehot = F.one_hot(input, num_classes=10)\n",
    "print(input_onehot.shape)\n",
    "input_onehot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d296d685",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:51:46.140016Z",
     "start_time": "2023-05-14T10:51:46.135504Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float32\n",
      "torch.Size([10, 3])\n"
     ]
    }
   ],
   "source": [
    "print(embedding.weight.dtype)\n",
    "print(embedding.weight.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "17f72f05",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:52:32.464582Z",
     "start_time": "2023-05-14T10:52:32.451993Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.4941,  0.1717, -0.2396],\n",
       "         [-1.8685,  1.2610, -0.5606],\n",
       "         [-0.7126, -0.8973, -2.2054],\n",
       "         [ 0.7383,  0.2399,  0.1330]],\n",
       "\n",
       "        [[-0.7126, -0.8973, -2.2054],\n",
       "         [ 0.8324,  1.0663,  1.2586],\n",
       "         [-1.8685,  1.2610, -0.5606],\n",
       "         [-0.6612, -0.9927, -0.4817]]], grad_fn=<UnsafeViewBackward0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# input_onehot.shape: (b, s, v)\n",
    "# embedding.weight.shape: (v, h)\n",
    "# (b, s, h)\n",
    "torch.matmul(input_onehot.type(torch.float32), embedding.weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1358c26",
   "metadata": {},
   "source": [
    "## max_norm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "635fab40",
   "metadata": {},
   "source": [
    "- https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html\n",
    "- When max_norm is not None, Embedding’s forward method will modify the weight tensor in-place. Since tensors needed for gradient computations cannot be modified in-place, performing a differentiable operation on Embedding.weight before calling Embedding’s forward method requires cloning Embedding.weight when max_norm is not None. For example:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d2ddcf9",
   "metadata": {},
   "source": [
    "### 不设置 max_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8ba6b22a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:54:53.793466Z",
     "start_time": "2023-05-14T10:54:53.779138Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0240, grad_fn=<MeanBackward0>)\n",
      "tensor(0.6895, grad_fn=<StdBackward0>)\n",
      "Parameter containing:\n",
      "tensor([[ 0.7605,  0.3294, -0.3857,  0.6366,  1.1262],\n",
      "        [-0.3753, -0.5345, -0.0094,  0.3832, -0.1250],\n",
      "        [-1.0760,  0.1818,  1.0244, -0.5830, -0.9929]], requires_grad=True)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([1.5840, 0.7675, 1.8884], grad_fn=<LinalgVectorNormBackward0>)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# max_norm == True ==> max_norm == 1\n",
    "embedding = nn.Embedding(3, 5,)\n",
    "print(embedding.weight.mean())\n",
    "print(embedding.weight.std())\n",
    "print(embedding.weight)\n",
    "torch.norm(embedding.weight, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5cd9d264",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:55:45.617026Z",
     "start_time": "2023-05-14T10:55:45.603929Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.7605,  0.3294, -0.3857,  0.6366,  1.1262],\n",
       "        [-0.3753, -0.5345, -0.0094,  0.3832, -0.1250],\n",
       "        [-1.0760,  0.1818,  1.0244, -0.5830, -0.9929]],\n",
       "       grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = torch.tensor([0, 1, 2])\n",
    "print(inputs.shape)\n",
    "outputs = embedding(inputs)\n",
    "outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9e8b9b77",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:55:53.771320Z",
     "start_time": "2023-05-14T10:55:53.759539Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.5840, 0.7675, 1.8884], grad_fn=<LinalgVectorNormBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.norm(embedding.weight, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf163010",
   "metadata": {},
   "source": [
    "### max_norm=True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "cc4dd4ac",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:56:23.781568Z",
     "start_time": "2023-05-14T10:56:23.764518Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.1015, grad_fn=<MeanBackward0>)\n",
      "tensor(1.2156, grad_fn=<StdBackward0>)\n",
      "Parameter containing:\n",
      "tensor([[ 0.7330, -0.2748,  0.5157, -2.5154,  1.3099],\n",
      "        [ 0.4129, -0.4855,  2.0899, -0.1314,  1.9055],\n",
      "        [-0.6868, -0.3076, -0.8595, -1.1350,  0.9513]], requires_grad=True)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([2.9869, 2.9021, 1.8704], grad_fn=<LinalgVectorNormBackward0>)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# max_norm == True ==> max_norm == 1\n",
    "embedding = nn.Embedding(3, 5, max_norm=True)\n",
    "print(embedding.weight.mean())\n",
    "print(embedding.weight.std())\n",
    "print(embedding.weight)\n",
    "torch.norm(embedding.weight, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1635aa29",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:56:40.214411Z",
     "start_time": "2023-05-14T10:56:40.201834Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2454, -0.0920,  0.1727, -0.8421,  0.4385],\n",
       "        [ 0.1423, -0.1673,  0.7201, -0.0453,  0.6566],\n",
       "        [-0.3672, -0.1644, -0.4596, -0.6068,  0.5086]],\n",
       "       grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = torch.tensor([0, 1, 2])\n",
    "print(inputs.shape)\n",
    "outputs = embedding(inputs)\n",
    "outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b50b73ba",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:57:05.627148Z",
     "start_time": "2023-05-14T10:57:05.618671Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.0000, 1.0000, 1.0000], grad_fn=<LinalgVectorNormBackward0>)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.norm(outputs, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8f5b7826",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:57:13.194777Z",
     "start_time": "2023-05-14T10:57:13.183047Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[ 0.2454, -0.0920,  0.1727, -0.8421,  0.4385],\n",
       "        [ 0.1423, -0.1673,  0.7201, -0.0453,  0.6566],\n",
       "        [-0.3672, -0.1644, -0.4596, -0.6068,  0.5086]], requires_grad=True)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "14dc417b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-14T10:57:20.636525Z",
     "start_time": "2023-05-14T10:57:20.624072Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.0000, 1.0000, 1.0000], grad_fn=<LinalgVectorNormBackward0>)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.norm(embedding.weight, dim=1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "227px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
